import typing
from typing import *

import torch
from torch import Tensor

from torch_geometric.typing import *
from torch_geometric import is_compiling
from torch_geometric.utils import is_sparse

from {{module}} import *


# Typed container for every value gathered by `_collect`. One field is
# generated per entry of the `collect_types` template variable (field name
# mapped to its type hint), in that mapping's iteration order. `_collect`
# fills the fields positionally, and `propagate` reads them by attribute name.
class CollectArgs(NamedTuple):
{%- for name, type_hint in collect_types.items() %}
    {{name}}: {{type_hint}}
{%- endfor %}


# Gathers the values that `propagate` later hands to `message`, `aggregate`
# and `update`, and returns them packed into a `CollectArgs` tuple.
#
# Parameters:
#   edge_index: Connectivity. It is indexed with `[i]` and `[j]` below and is
#       asserted to be a `Tensor` before that indexing happens.
#   (generated): One parameter per entry of the `propagate_types` template
#       variable, using the declared name and type hint.
#   size: List of optional ints, read at positions `i` and `j` (0 and 1). It
#       is handed to `self._set_size` for every tensor input; presumably
#       that helper fills in missing entries in place -- TODO confirm.
#
# Returns:
#   A `CollectArgs` instance built from the local variables named in
#   `collect_types`.
def _collect(
    self,
    edge_index: Union[Tensor, SparseTensor],
{%- for name, type_hint in propagate_types.items() %}
    {{name}}: {{type_hint}},
{%- endfor %}
    size: List[Optional[int]],
) -> CollectArgs:

    # `i` and `j` select the row of `edge_index` and the slot of `size` used
    # for the `_i` and `_j` quantities; they swap with the flow direction.
    i, j = (1, 0) if self.flow == 'source_to_target' else (0, 1)
{% for name in collect_types %}
{%- if name.endswith('_i') or name.endswith('_j') %}
    # Collect `{{name}}`:
    if isinstance({{name[:-2]}}, (tuple, list)):
        # Pair input: element 0 feeds the `_j` variant and element 1 feeds the
        # `_i` variant. Both elements are offered to `self._set_size`
        # regardless of which suffix is being collected.
        assert len({{name[:-2]}}) == 2
        _{{name[:-2]}}_0, _{{name[:-2]}}_1 = {{name[:-2]}}[0], {{name[:-2]}}[1]
        if isinstance(_{{name[:-2]}}_0, Tensor):
            self._set_size(size, 0, _{{name[:-2]}}_0)
{%- if name.endswith('_j') %}
            {{name}} = self._lift(_{{name[:-2]}}_0, edge_index, {{name[-1]}})
        else:
            {{name}} = None
{%- endif %}
        if isinstance(_{{name[:-2]}}_1, Tensor):
            self._set_size(size, 1, _{{name[:-2]}}_1)
{%- if name.endswith('_i') %}
            {{name}} = self._lift(_{{name[:-2]}}_1, edge_index, {{name[-1]}})
        else:
            {{name}} = None
{%- endif %}
    elif isinstance({{name[:-2]}}, Tensor):
        # A single tensor is offered for both size slots and lifted directly.
        self._set_size(size, 0, {{name[:-2]}})
        self._set_size(size, 1, {{name[:-2]}})
        {{name}} = self._lift({{name[:-2]}}, edge_index, {{name[-1]}})
    else:
        # Neither a pair nor a tensor: there is nothing to lift.
        {{name}} = None
{%- endif %}
{%- endfor %}

    # Fixed locals that `collect_types` may refer to by name. `adj_t` and
    # `ptr` are always None here. NOTE(review): this assert runs after the
    # `self._lift` calls above have already received `edge_index`.
    assert isinstance(edge_index, Tensor)
    adj_t = None
    edge_index_i = edge_index[i]
    edge_index_j = edge_index[j]
    ptr = None
    index = edge_index_i

    size_i = size[i]
    size_j = size[j]
    dim_size = size_i

    # Positional construction: `CollectArgs` declares its fields by iterating
    # the same `collect_types` mapping, so the order lines up. Every name in
    # `collect_types` must therefore be a variable in scope at this point.
    return CollectArgs({% for name in collect_types %}{{name}}{% if not loop.last %}, {% endif %}{% endfor %})


# Generated entry point: runs the registered pre-hooks, validates the input
# via `self._check_input`, collects arguments with `self._collect`, and then
# chains `self.message`, `self.aggregate` and `self.update`.
#
# Parameters:
#   edge_index: Connectivity, forwarded to the hooks, `self._check_input`
#       and `self._collect`.
#   (generated): One parameter per entry of the `propagate_types` template
#       variable, using the declared name and type hint.
#   size: Optional size hint (defaults to None), forwarded to the hooks and
#       to `self._check_input`.
#
# Returns:
#   The result of `self.update`, annotated with the `propagate_return_type`
#   template variable.
#
# Raises:
#   NotImplementedError: when the fused path would be selected (see `fuse`).
def propagate(
    self,
    edge_index: Union[Tensor, SparseTensor],
{%- for name, type_hint in propagate_types.items() %}
    {{name}}: {{type_hint}},
{%- endfor %}
    size: Size = None,
) -> {{propagate_return_type}}:

    # NOTE(review): `decomposed_layers` is computed here but not read again
    # anywhere in this function as shown.
    decomposed_layers = 1 if self.explain is True else self.decomposed_layers

    # Pre-hooks are skipped under TorchScript scripting and when
    # `is_compiling()` is true.
    if not torch.jit.is_scripting() and not is_compiling():
        for hook in self._propagate_forward_pre_hooks.values():
            # The dict is rebuilt for every hook from the current argument
            # values, so each hook sees the changes made by the previous one.
            kwargs = dict({% for name in propagate_types %}{{name}}={{name}}{% if not loop.last %}, {% endif %}{% endfor %})
            res = hook(self, (edge_index, size, kwargs))
            if res is not None:
                # A hook may return a replacement triple; every propagate
                # argument is then re-bound from the returned dict by key.
                edge_index, size, kwargs = res
{%- for name, type_hint in propagate_types.items() %}
                {{name}} = kwargs['{{name}}']
{%- endfor %}

    # The result of `self._check_input` is what `_collect` receives as `size`.
    mutable_size = self._check_input(edge_index, size)
    # The fused path is selected for a sparse `edge_index` when `self.fuse`
    # is set and `self.explain` is not True.
    fuse = is_sparse(edge_index) and self.fuse and self.explain is not True

    if fuse:
        raise NotImplementedError

    else:
        # Arguments are passed positionally in `propagate_types` order, which
        # matches the generated `_collect` signature. From here on `kwargs`
        # is the `CollectArgs` tuple, not the dict used by the hooks above.
        kwargs = self._collect(edge_index, {% for name in propagate_types %}{{name}}, {% endfor %}mutable_size)
        # Each stage receives only the collected fields listed in its own
        # template variable: `message_args`, `aggregate_args`, `update_args`.
        out = self.message({% for name in message_args %}{{name}}=kwargs.{{name}}{% if not loop.last %}, {% endif %}{% endfor %})
        out = self.aggregate(out{% for name in aggregate_args %}, {{name}}=kwargs.{{name}}{% endfor %})
        out = self.update(out{% for name in update_args %}, {{name}}=kwargs.{{name}}{% endfor %})
        return out
